<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes">
    <title>WebGPU Compute Shaders - Histogram</title>
    <style>
      @import url(resources/webgpu-lesson.css);
      canvas {
        display: block;
        max-width: 256px;
        border: 1px solid #888;
        background-color: #333;
      }
    </style>
  </head>
  <body>
    <canvas></canvas>
  </body>
  <script type="module">
import TimingHelper from './resources/js/timing-helper.js';
// see https://webgpufundamentals.org/webgpu/lessons/webgpu-utils.html#webgpu-utils
import {
  loadImageBitmap,
  createTextureFromSource,
} from '../3rdparty/webgpu-utils-1.x.module.js';

/**
 * Computes a 256-bin histogram (red, green, blue, luminance) of an image on
 * the GPU and logs it to the console.
 *
 * The work is split into two compute passes:
 *   1. "histogram chunk": the image is divided into chunks of
 *      chunkWidth x chunkHeight pixels. One workgroup processes one chunk,
 *      accumulating into workgroup memory, and then writes its partial
 *      histogram to `chunksBuffer`.
 *   2. "chunk sum": a single workgroup adds all partial histograms together
 *      into `histogramBuffer`.
 *
 * The result is copied to a mappable buffer, read back and printed. Returns
 * early (after calling fail()) when no WebGPU device can be obtained.
 */
async function main() {
  const adapter = await navigator.gpu?.requestAdapter();
  const canTimestamp = adapter?.features.has('timestamp-query');
  const device = await adapter?.requestDevice({
    requiredFeatures: [
      ...(canTimestamp ? ['timestamp-query'] : []),
    ],
  });
  if (!device) {
    fail('need a browser that supports WebGPU');
    return;
  }

  const timingHelper = new TimingHelper(device);
  const timingHelper2 = new TimingHelper(device);

  // Get a WebGPU context from the canvas and configure it
  const canvas = document.querySelector('canvas');
  const context = canvas.getContext('webgpu');
  const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
  context.configure({
    device,
    format: presentationFormat,
  });

  // Single source of truth for the chunk geometry. These values are
  // interpolated into the WGSL below AND used for the dispatch size and
  // buffer sizes, so the shader and the JavaScript cannot drift apart.
  //
  // Invariant: chunkWidth * chunkHeight must equal numBins, because each
  // invocation of a workgroup writes exactly one bin of the chunk's partial
  // histogram after the barrier.
  const numBins = 256;
  const chunkWidth = 256;
  const chunkHeight = 1;

  const histogramChunkModule = device.createShaderModule({
    label: 'histogram chunk shader',
    code: /* wgsl */ `
      const chunkWidth = ${chunkWidth}u;
      const chunkHeight = ${chunkHeight}u;
      const numBins = ${numBins}u;

      var<workgroup> histogram: array<array<atomic<u32>, 4>, numBins>;
      @group(0) @binding(0) var<storage, read_write> histogramChunks: array<array<vec4u, numBins>>;
      @group(0) @binding(1) var ourTexture: texture_2d<f32>;

      const kSRGBLuminanceFactors = vec3f(0.2126, 0.7152, 0.0722);
      fn srgbLuminance(color: vec3f) -> f32 {
        return saturate(dot(color, kSRGBLuminanceFactors));
      }

      @compute @workgroup_size(chunkWidth, chunkHeight, 1) fn cs(
        @builtin(workgroup_id) workgroup_id: vec3u,
        @builtin(local_invocation_id) local_invocation_id: vec3u,
      ) {
        let size = textureDimensions(ourTexture, 0);
        // Pixel handled by this invocation: chunk origin + offset in chunk.
        let position = workgroup_id.xy * vec2u(chunkWidth, chunkHeight) +
                       local_invocation_id.xy;
        // Edge chunks may extend past the texture; skip those pixels.
        if (all(position < size)) {
          let numBinsF = f32(numBins);
          let lastBinIndex = numBins - 1;
          var channels = textureLoad(ourTexture, position, 0);
          // Reuse the alpha slot to carry luminance as the 4th histogram.
          channels.a = srgbLuminance(channels.rgb);
          for (var ch = 0; ch < 4; ch++) {
            // min() keeps a channel value of exactly 1.0 in the last bin.
            let bin = min(u32(channels[ch] * numBinsF), lastBinIndex);
            atomicAdd(&histogram[bin][ch], 1u);
          }
        }

        // Wait until every invocation in the workgroup has added its pixel.
        // This is outside the 'if' so all invocations reach the barrier.
        workgroupBarrier();

        let chunksAcross = (size.x + chunkWidth - 1) / chunkWidth;
        let chunk = workgroup_id.y * chunksAcross + workgroup_id.x;

        // Each invocation copies one bin of the partial histogram out.
        let ndx = local_invocation_id.y * chunkWidth + local_invocation_id.x;
        histogramChunks[chunk][ndx] = vec4u(
          atomicLoad(&histogram[ndx][0]),
          atomicLoad(&histogram[ndx][1]),
          atomicLoad(&histogram[ndx][2]),
          atomicLoad(&histogram[ndx][3]),
        );
      }
    `,
  });

  const chunkSumModule = device.createShaderModule({
    label: 'chunk sum shader',
    code: /* wgsl */ `
      const numBins = ${numBins}u;

      @group(0) @binding(0) var<storage, read> histogramChunks: array<array<vec4u, numBins>>;
      @group(0) @binding(1) var<storage, read_write> histogram: array<vec4u, numBins>;

      // One invocation per bin; each sums that bin across every chunk.
      @compute @workgroup_size(numBins, 1, 1) fn cs(
        @builtin(local_invocation_id) local_invocation_id: vec3u,
      ) {
        var sum = vec4u(0);
        let numChunks = arrayLength(&histogramChunks);
        for (var i = 0u; i < numChunks; i++) {
          sum += histogramChunks[i][local_invocation_id.x];
        }
        histogram[local_invocation_id.x] = sum;
      }
    `,
  });

  const histogramChunkPipeline = device.createComputePipeline({
    label: 'histogram',
    layout: 'auto',
    compute: {
      module: histogramChunkModule,
    },
  });

  const chunkSumPipeline = device.createComputePipeline({
    label: 'chunk sum',
    layout: 'auto',
    compute: {
      module: chunkSumModule,
    },
  });

  const imgBitmap = await loadImageBitmap('resources/images/pexels-chevanon-photography-1108099.jpg'); /* webgpufundamentals: url */
  const texture = createTextureFromSource(device, imgBitmap);

  const histogramBuffer = device.createBuffer({
    size: numBins * 4 * 4, // 256 entries * 4 (rgba) * 4 bytes per (u32)
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
  });

  // One partial histogram per chunk; edge chunks are rounded up.
  const chunksAcross = Math.ceil(texture.width / chunkWidth);
  const chunksDown = Math.ceil(texture.height / chunkHeight);
  const numChunks = chunksAcross * chunksDown;
  const chunksBuffer = device.createBuffer({
    size: numChunks * numBins * 4 * 4,
    usage: GPUBufferUsage.STORAGE,
  });

  // STORAGE buffers can not be mapped, so results are copied here to read.
  const resultBuffer = device.createBuffer({
    size: histogramBuffer.size,
    usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
  });

  const histogramBindGroup = device.createBindGroup({
    layout: histogramChunkPipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: chunksBuffer }},
      { binding: 1, resource: texture.createView() },
    ],
  });

  const chunkSumBindGroup = device.createBindGroup({
    layout: chunkSumPipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: chunksBuffer }},
      { binding: 1, resource: { buffer: histogramBuffer }},
    ],
  });

  const encoder = device.createCommandEncoder({
    label: 'histogram encoder',
  });

  // Pass 1: one workgroup per chunk of the image.
  {
    const pass = timingHelper.beginComputePass(encoder);
    pass.setPipeline(histogramChunkPipeline);
    pass.setBindGroup(0, histogramBindGroup);
    pass.dispatchWorkgroups(chunksAcross, chunksDown);
    pass.end();
  }

  // Pass 2: a single workgroup reduces all chunks into the final histogram.
  {
    const pass = timingHelper2.beginComputePass(encoder);
    pass.setPipeline(chunkSumPipeline);
    pass.setBindGroup(0, chunkSumBindGroup);
    pass.dispatchWorkgroups(1);
    pass.end();
  }

  encoder.copyBufferToBuffer(histogramBuffer, 0, resultBuffer, 0, resultBuffer.size);

  const commandBuffer = encoder.finish();
  device.queue.submit([commandBuffer]);

  timingHelper.getResult().then(duration => {
    console.log(`duration 1: ${duration}ns`);
  });
  timingHelper2.getResult().then(duration => {
    console.log(`duration 2: ${duration}ns`);
  });

  // Called after the texture was created from the bitmap because
  // showImageBitmap hands the bitmap over to a bitmaprenderer canvas.
  showImageBitmap(imgBitmap);

  await resultBuffer.mapAsync(GPUMapMode.READ);
  const result = new Uint32Array(resultBuffer.getMappedRange());

  const to3 = v => v.toString().padStart(3);

  // Print every bin and total each channel. Each channel's total should
  // equal the number of pixels in the image.
  const sum = [0, 0, 0, 0];
  for (let i = 0; i < numBins; ++i) {
    const off = i * 4;
    console.log(to3(i), to3(result[off]), to3(result[off + 1]), to3(result[off + 2]), to3(result[off + 3]));
    for (let j = 0; j < 4; ++j) {
      sum[j] += result[off + j];
    }
  }
  console.log('sum:', sum);
}

/**
 * Appends a new canvas to the document body that displays the given bitmap.
 *
 * @param {ImageBitmap} imageBitmap bitmap to display; it is passed to
 *     transferFromImageBitmap on a 'bitmaprenderer' context.
 */
function showImageBitmap(imageBitmap) {
  const { width, height } = imageBitmap;

  // we have to set the canvas size because of a firefox bug
  // https://bugzilla.mozilla.org/show_bug.cgi?id=1850871
  const displayCanvas = Object.assign(
    document.createElement('canvas'),
    { width, height },
  );

  displayCanvas
    .getContext('bitmaprenderer')
    .transferFromImageBitmap(imageBitmap);
  document.body.appendChild(displayCanvas);
}

/**
 * Reports a fatal problem to the user with a blocking alert dialog.
 *
 * @param {string} msg text to show.
 */
function fail(msg) {
  // eslint-disable-next-line no-alert
  window.alert(msg);
}

// Start the demo. The promise returned by main() is neither awaited nor
// given a rejection handler here.
main();
  </script>
</html>
